
if __name__ == '__main__':
    # NOTE(review): the original script called `paddle.fleat()` here. There is
    # no such function in the PaddlePaddle API, so that call raised
    # AttributeError before anything else ran; it has been removed. If a setup
    # call was intended (e.g. `paddle.seed(...)` or `paddle.set_device(...)`),
    # restore the correct one here — TODO confirm with the author.

    # Run-wide hyperparameters. They are bound at module level (inside the
    # `__main__` guard, not inside a function), so presumably `train` and
    # `predict` read them as globals — verify before moving them into a
    # function.
    batch_size = 1
    accumulate_batchs_num = 20  # presumably gradient-accumulation steps — confirm in train()
    total_epoch = 100000
    lr = 0.0001
    pred_batch_size = 1

    # NOTE(review): this configuration is very large (128 encoder + 128
    # decoder layers at emb_dim=4096), and dim_feedforward (512) is smaller
    # than emb_dim, which is atypical for a Transformer; confirm these values
    # are intentional. `max_len` and `dict_len` are not defined in this block;
    # presumably they are module-level names set earlier in the file.
    model = SmilesTransformer(emb_dim=4096, n_head=512,
                              num_encoder_layers=128,
                              num_decoder_layers=128,
                              dim_feedforward=512,
                              max_len=max_len,
                              encoder_lib_size=dict_len,
                              output_lib_size=dict_len)

    # Print a layer/parameter summary for two int64 inputs of shape
    # (1, max_len) — presumably the source and target token-id sequences.
    paddle.summary(model, ((1, max_len), (1, max_len)), dtypes='int64')

    # First run: build the network and train it from scratch.
    train(model, pre_train=False)

    # To load the pre-trained model and continue training instead, use:
    # train(model, pre_train=True)

    # Output results (run prediction with the model).
    predict(model, predict=True)